import torch
import torch.nn as nn
from torch.nn import init
import functools
from torchvision import models
import torch.nn.functional as F
from torch.optim import lr_scheduler
import math
import utils
import matplotlib.pyplot as plt
import numpy as np

# Default compute device for this module: the first CUDA GPU ("cuda:0") when
# CUDA is available, otherwise the CPU. Evaluated once, at import time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Module-level shorthand for math.pi.
PI = math.pi
###############################################################################
# Helper Functions
###############################################################################


class Identity(nn.Module):
	"""Pass-through module: ``forward`` returns its input unchanged.

	``get_norm_layer('none')`` hands out instances of this class in place of a
	real normalization layer.
	"""

	def forward(self, x):
		"""Return ``x`` itself (same object, no copy)."""
		return x

def get_norm_layer(norm_type='instance'):
	"""Build a factory for 2-D normalization layers.

	Parameters:
		norm_type (str) -- which normalization to use: batch | instance | none

	Returns:
		A callable that takes one argument (the number of features) and returns
		a freshly constructed layer:
		  * 'batch'    -- nn.BatchNorm2d with learnable affine parameters and
		                  tracked running statistics.
		  * 'instance' -- nn.InstanceNorm2d without affine parameters and
		                  without running statistics.
		  * 'none'     -- an Identity module; the argument is ignored.

	Raises:
		NotImplementedError -- for any other norm_type.
	"""
	if norm_type == 'none':
		# Same one-argument call shape as the real layers; the feature count
		# is accepted and discarded.
		return lambda x: Identity()

	if norm_type == 'batch':
		return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)

	if norm_type == 'instance':
		return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

	raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def init_weights(net, init_type='normal', init_gain=0.02):
	"""Initialize network weights in place.

	Parameters:
		net (network)     -- network to be initialized
		init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
		init_gain (float) -- scaling factor for normal, xavier and orthogonal.

	Raises:
		NotImplementedError -- if init_type is not one of the names above and
		                       the network contains at least one Conv*/Linear*
		                       layer (the name is only checked when such a
		                       layer is visited).

	Layers are matched by class name: any module whose class name contains
	'Conv' or 'Linear' gets its weight drawn with the chosen method and its bias
	(if any) zeroed; any module whose class name contains 'BatchNorm2d' gets
	weight ~ N(1, init_gain) and bias = 0. Other modules are left untouched.

	We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
	work better for some applications. Feel free to try yourself.
	"""
	def init_func(m):  # applied to every sub-module by net.apply()
		classname = m.__class__.__name__
		if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
			if init_type == 'normal':
				init.normal_(m.weight.data, 0.0, init_gain)
			elif init_type == 'xavier':
				init.xavier_normal_(m.weight.data, gain=init_gain)
			elif init_type == 'kaiming':
				init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
			elif init_type == 'orthogonal':
				init.orthogonal_(m.weight.data, gain=init_gain)
			else:
				raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
			if getattr(m, 'bias', None) is not None:
				init.constant_(m.bias.data, 0.0)
		elif 'BatchNorm2d' in classname:
			# BatchNorm's weight is not a matrix, so only the normal distribution applies.
			# With affine=False the layer has weight = bias = None; skip those
			# instead of dereferencing None.
			if getattr(m, 'weight', None) is not None:
				init.normal_(m.weight.data, 1.0, init_gain)
			if getattr(m, 'bias', None) is not None:
				init.constant_(m.bias.data, 0.0)

	print('initialize network with %s' % init_type)
	net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
	"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

	Parameters:
		net (network)      -- the network to be initialized
		init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
		init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
		gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2.
		                      None or an empty list leaves the network where it
		                      is and does not wrap it.

	Return an initialized network. When gpu_ids is non-empty the returned
	object is a torch.nn.DataParallel wrapper around ``net`` (moved to
	gpu_ids[0]), not ``net`` itself.
	"""
	# None is the default instead of a shared mutable list; both mean "no GPUs".
	if gpu_ids is not None and len(gpu_ids) > 0:
		assert torch.cuda.is_available(), 'gpu_ids was given but CUDA is not available'
		net.to(gpu_ids[0])
		net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
	# Weights are initialized after the optional wrap; apply() reaches the
	# wrapped module's layers as well.
	init_weights(net, init_type, init_gain=init_gain)
	return net


def define_G(rdrr, netG, init_type='normal', init_gain=0.02, gpu_ids=[]):
	"""Create a generator by name and run it through ``init_net``.

	Parameters:
		rdrr               -- renderer object handed to the generator constructor
		netG (str)         -- generator name: plain-dcgan | plain-unet | huang-net |
		                      zou-fusion-net | zou-fusion-net-light
		init_type (str)    -- forwarded to init_net
		init_gain (float)  -- forwarded to init_net
		gpu_ids (int list) -- forwarded to init_net

	Returns:
		Whatever init_net returns for the freshly built generator.

	Raises:
		NotImplementedError -- if netG is not one of the names above.
	"""
	# Builders are wrapped in lambdas so that only the selected generator is
	# constructed.
	builders = {
		'plain-dcgan': lambda: DCGAN(rdrr),
		'plain-unet': lambda: UNet(rdrr),
		'huang-net': lambda: HuangNet(rdrr),
		'zou-fusion-net': lambda: ZouFCNFusion(rdrr),
		'zou-fusion-net-light': lambda: ZouFCNFusionLight(rdrr),
	}
	build = builders.get(netG)
	if build is None:
		raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
	return init_net(build(), init_type, init_gain, gpu_ids)


class DCGAN(nn.Module):
	"""Stack of transposed convolutions that emits a 6-channel map.

	The first layer (kernel 4, stride 1, no padding) turns a 1x1 spatial input
	into 4x4; each of the five following layers (kernel 4, stride 2, padding 1)
	doubles the resolution, so a (N, rdrr.d, 1, 1) input yields
	(N, 6, 128, 128). ``forward`` returns that map split into two
	3-channel halves.

	Parameters:
		rdrr      -- object whose ``d`` attribute gives the number of input channels
		ngf (int) -- base channel width; hidden layers use multiples of it
	"""

	def __init__(self, rdrr, ngf=64):
		super().__init__()
		latent_nc = rdrr.d
		self.out_size = 128

		# Output channels of each ConvTranspose2d + BatchNorm2d + ReLU stage.
		# Resulting feature maps (for a 1x1 input):
		#   ngf*8 @ 4x4, ngf*8 @ 8x8, ngf*4 @ 16x16, ngf*2 @ 32x32, ngf @ 64x64
		stage_widths = [ngf * 8, ngf * 8, ngf * 4, ngf * 2, ngf]

		layers = []
		in_nc = latent_nc
		for idx, out_nc in enumerate(stage_widths):
			# Only the very first stage uses stride 1 / padding 0.
			stride, padding = (1, 0) if idx == 0 else (2, 1)
			layers.append(nn.ConvTranspose2d(in_nc, out_nc, 4, stride, padding, bias=False))
			layers.append(nn.BatchNorm2d(out_nc))
			layers.append(nn.ReLU(True))
			in_nc = out_nc

		# Final upsampling to 6 @ 128x128; no normalization or activation.
		layers.append(nn.ConvTranspose2d(in_nc, 6, 4, 2, 1, bias=False))
		self.main = nn.Sequential(*layers)

	def forward(self, input):
		"""Run the stack and return (channels 0-2, channels 3-5) of its output."""
		generated = self.main(input)
		first_half = generated[:, :3]
		second_half = generated[:, 3:6]
		return first_half, second_half



class DCGAN_32(nn.Module):
	"""Shallower transposed-convolution stack that emits a 6-channel map.

	For a (N, rdrr.d, 1, 1) input the feature maps are ngf*8 @ 4x4,
	ngf*4 @ 8x8, ngf*2 @ 16x16 and finally 6 @ 32x32. ``forward`` returns the
	final map split into two 3-channel halves.

	Parameters:
		rdrr      -- object whose ``d`` attribute gives the number of input channels
		ngf (int) -- base channel width; hidden layers use multiples of it
	"""

	def __init__(self, rdrr, ngf=64):
		super().__init__()
		latent_nc = rdrr.d
		self.out_size = 32

		def up_stage(c_in, c_out, stride, padding):
			# One ConvTranspose2d (kernel 4, no bias) + BatchNorm2d + in-place ReLU.
			return [
				nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
				nn.BatchNorm2d(c_out),
				nn.ReLU(True),
			]

		stages = (
			up_stage(latent_nc, ngf * 8, 1, 0)    # -> ngf*8 @ 4x4
			+ up_stage(ngf * 8, ngf * 4, 2, 1)    # -> ngf*4 @ 8x8
			+ up_stage(ngf * 4, ngf * 2, 2, 1)    # -> ngf*2 @ 16x16
		)
		# Final upsampling to 6 @ 32x32; no normalization or activation.
		head = nn.ConvTranspose2d(ngf * 2, 6, 4, 2, 1, bias=False)
		self.main = nn.Sequential(*stages, head)

	def forward(self, input):
		"""Run the stack and return (channels 0-2, channels 3-5) of its output."""
		generated = self.main(input)
		first_half = generated[:, :3]
		second_half = generated[:, 3:6]
		return first_half, second_half



class PixelShuffleNet(nn.Module):
	"""Fully-connected + pixel-shuffle network producing a 3 x 128 x 128 map.

	Four linear layers expand an ``input_nc``-dimensional vector to 4096
	values, which are viewed as a 16 x 16 x 16 feature map. Three
	conv / conv / PixelShuffle(2) stages then upsample 16 -> 32 -> 64 -> 128.

	Parameters:
		input_nc (int) -- length of the input feature vector
	"""

	def __init__(self, input_nc):
		super(PixelShuffleNet, self).__init__()
		self.fc1 = nn.Linear(input_nc, 512)
		self.fc2 = nn.Linear(512, 1024)
		self.fc3 = nn.Linear(1024, 2048)
		self.fc4 = nn.Linear(2048, 4096)
		self.conv1 = nn.Conv2d(16, 32, 3, 1, 1)
		self.conv2 = nn.Conv2d(32, 32, 3, 1, 1)
		self.conv3 = nn.Conv2d(8, 16, 3, 1, 1)
		self.conv4 = nn.Conv2d(16, 16, 3, 1, 1)
		self.conv5 = nn.Conv2d(4, 8, 3, 1, 1)
		self.conv6 = nn.Conv2d(8, 4*3, 3, 1, 1)
		self.pixel_shuffle = nn.PixelShuffle(2)

	def forward(self, x):
		"""Map ``x`` to a tensor of shape (N, 3, 128, 128).

		``x`` may carry singleton dimensions, e.g. (N, input_nc, 1, 1); it is
		flattened to (N, input_nc) before the linear layers.
		"""
		# Explicit reshape instead of squeeze(): squeeze() would also drop the
		# feature dimension when input_nc == 1, which breaks fc1 for N > 1.
		x = x.reshape(-1, self.fc1.in_features)
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		x = F.relu(self.fc3(x))
		x = F.relu(self.fc4(x))
		x = x.view(-1, 16, 16, 16)
		x = F.relu(self.conv1(x))
		x = self.pixel_shuffle(self.conv2(x))   # 32 ch @ 16x16 -> 8 ch @ 32x32
		x = F.relu(self.conv3(x))
		x = self.pixel_shuffle(self.conv4(x))   # 16 ch @ 32x32 -> 4 ch @ 64x64
		x = F.relu(self.conv5(x))
		x = self.pixel_shuffle(self.conv6(x))   # 12 ch @ 64x64 -> 3 ch @ 128x128
		x = x.view(-1, 3, 128, 128)
		return x



class PixelShuffleNet_32(nn.Module):
	"""Fully-connected + pixel-shuffle network producing a 3 x 32 x 32 map.

	Three linear layers expand an ``input_nc``-dimensional vector to 2048
	values, which are viewed as an 8 x 16 x 16 feature map. Two convolutions
	followed by one PixelShuffle(2) upsample it to 3 x 32 x 32.

	Parameters:
		input_nc (int) -- length of the input feature vector
	"""

	def __init__(self, input_nc):
		super(PixelShuffleNet_32, self).__init__()
		self.fc1 = nn.Linear(input_nc, 512)
		self.fc2 = nn.Linear(512, 1024)
		self.fc3 = nn.Linear(1024, 2048)
		self.conv1 = nn.Conv2d(8, 64, 3, 1, 1)
		self.conv2 = nn.Conv2d(64, 4*3, 3, 1, 1)
		self.pixel_shuffle = nn.PixelShuffle(2)

	def forward(self, x):
		"""Map ``x`` to a tensor of shape (N, 3, 32, 32).

		``x`` may carry singleton dimensions, e.g. (N, input_nc, 1, 1); it is
		flattened to (N, input_nc) before the linear layers.
		"""
		# Explicit reshape instead of squeeze(): squeeze() would also drop the
		# feature dimension when input_nc == 1, which breaks fc1 for N > 1.
		x = x.reshape(-1, self.fc1.in_features)
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		x = F.relu(self.fc3(x))
		x = x.view(-1, 8, 16, 16)
		x = F.relu(self.conv1(x))
		x = self.pixel_shuffle(self.conv2(x))   # 12 ch @ 16x16 -> 3 ch @ 32x32
		x = x.view(-1, 3, 32, 32)
		return x




class HuangNet(nn.Module):
	"""Fully-connected + pixel-shuffle generator emitting a 6 x 128 x 128 map.

	Same layout as PixelShuffleNet except that the last convolution produces
	4 * 6 channels, so the final PixelShuffle(2) yields 6 channels. ``forward``
	returns that map split into two 3-channel halves.

	Parameters:
		rdrr -- object whose ``d`` attribute gives the input vector length;
		        also stored as ``self.rdrr``
	"""

	def __init__(self, rdrr):
		super(HuangNet, self).__init__()
		self.rdrr = rdrr
		self.out_size = 128
		self.fc1 = nn.Linear(rdrr.d, 512)
		self.fc2 = nn.Linear(512, 1024)
		self.fc3 = nn.Linear(1024, 2048)
		self.fc4 = nn.Linear(2048, 4096)
		self.conv1 = nn.Conv2d(16, 32, 3, 1, 1)
		self.conv2 = nn.Conv2d(32, 32, 3, 1, 1)
		self.conv3 = nn.Conv2d(8, 16, 3, 1, 1)
		self.conv4 = nn.Conv2d(16, 16, 3, 1, 1)
		self.conv5 = nn.Conv2d(4, 8, 3, 1, 1)
		self.conv6 = nn.Conv2d(8, 4 * 6, 3, 1, 1)
		self.pixel_shuffle = nn.PixelShuffle(2)

	def forward(self, x):
		"""Return (channels 0-2, channels 3-5) of the generated (N, 6, 128, 128) map.

		``x`` may carry singleton dimensions, e.g. (N, rdrr.d, 1, 1); it is
		flattened to (N, rdrr.d) before the linear layers.
		"""
		# Explicit reshape instead of squeeze(): squeeze() would also drop the
		# feature dimension when rdrr.d == 1, which breaks fc1 for N > 1.
		x = x.reshape(-1, self.fc1.in_features)
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		x = F.relu(self.fc3(x))
		x = F.relu(self.fc4(x))
		x = x.view(-1, 16, 16, 16)
		x = F.relu(self.conv1(x))
		x = self.pixel_shuffle(self.conv2(x))   # 32 ch @ 16x16 -> 8 ch @ 32x32
		x = F.relu(self.conv3(x))
		x = self.pixel_shuffle(self.conv4(x))   # 16 ch @ 32x32 -> 4 ch @ 64x64
		x = F.relu(self.conv5(x))
		x = self.pixel_shuffle(self.conv6(x))   # 24 ch @ 64x64 -> 6 ch @ 128x128
		output_tensor = x.view(-1, 6, 128, 128)
		return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]



class ZouFCNFusion(nn.Module):
	"""Two-branch generator that multiplies a DCGAN output by a PixelShuffleNet mask.

	``huangnet`` (a PixelShuffleNet) sees only the first ``rdrr.d_shape`` input
	channels and produces ``mask``; ``dcgan`` sees the full input and its first
	output is used as ``color``. Output size is 128 x 128.

	Parameters:
		rdrr -- object providing ``d`` (used by DCGAN), ``d_shape`` and
		        ``renderer``; stored as ``self.rdrr``
	"""

	def __init__(self, rdrr):
		super(ZouFCNFusion, self).__init__()
		self.rdrr = rdrr
		self.out_size = 128
		self.huangnet = PixelShuffleNet(rdrr.d_shape)
		self.dcgan = DCGAN(rdrr)

	def forward(self, x):
		"""Return ``(color * mask, x_alpha * mask)``.

		``x_alpha`` is the last channel of ``x``, except for the
		'oilpaintbrush' and 'airbrush' renderers where it is the constant 1.
		"""
		x_shape = x[:, 0:self.rdrr.d_shape, :, :]

		mask = self.huangnet(x_shape)
		color, _ = self.dcgan(x)

		if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
			# Build the constant on the same device/dtype as the data it
			# multiplies, not on a module-level default device, so the
			# network works wherever it has been moved.
			x_alpha = torch.ones((), dtype=mask.dtype, device=mask.device)
		else:
			x_alpha = x[:, [-1], :, :]

		return color * mask, x_alpha * mask



class ZouFCNFusionLight(nn.Module):
	"""32 x 32 variant of the two-branch mask-times-color generator.

	``huangnet`` (a PixelShuffleNet_32) sees only the first ``rdrr.d_shape``
	input channels and produces ``mask``; ``dcgan`` (a DCGAN_32) sees the full
	input and its first output is used as ``color``.

	Parameters:
		rdrr -- object providing ``d`` (used by DCGAN_32), ``d_shape`` and
		        ``renderer``; stored as ``self.rdrr``
	"""

	def __init__(self, rdrr):
		super(ZouFCNFusionLight, self).__init__()
		self.rdrr = rdrr
		self.out_size = 32
		self.huangnet = PixelShuffleNet_32(rdrr.d_shape)
		self.dcgan = DCGAN_32(rdrr)

	def forward(self, x):
		"""Return ``(color * mask, x_alpha * mask)``.

		``x_alpha`` is the last channel of ``x``, except for the
		'oilpaintbrush' and 'airbrush' renderers where it is the constant 1.
		"""
		x_shape = x[:, 0:self.rdrr.d_shape, :, :]

		mask = self.huangnet(x_shape)
		color, _ = self.dcgan(x)

		if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
			# Build the constant on the same device/dtype as the data it
			# multiplies, not on a module-level default device, so the
			# network works wherever it has been moved.
			x_alpha = torch.ones((), dtype=mask.dtype, device=mask.device)
		else:
			x_alpha = x[:, [-1], :, :]

		return color * mask, x_alpha * mask




class UNet(torch.nn.Module):
	"""U-Net generator: tiles the input over a 128 x 128 grid and runs a UnetGenerator.

	The wrapped UnetGenerator is built with ``rdrr.d`` input channels, 6 output
	channels, 7 downsampling levels, batch normalization and no dropout.
	``forward`` returns the 6-channel output split into two 3-channel halves.

	Parameters:
		rdrr -- object whose ``d`` attribute gives the number of input channels
	"""

	# Spatial size of the generated maps. Exposed under the same name the other
	# generators in this file use (DCGAN, HuangNet, ZouFCNFusion, ...).
	out_size = 128

	def __init__(self, rdrr):
		"""Build the batch-norm U-Net backbone for ``rdrr.d`` input channels."""
		super(UNet, self).__init__()
		norm_layer = get_norm_layer(norm_type='batch')
		self.unet = UnetGenerator(rdrr.d, 6, 7, norm_layer=norm_layer, use_dropout=False)

	def forward(self, x):
		"""Return (channels 0-2, channels 3-5) of the U-Net output.

		``x`` is repeated ``out_size`` times along each of its last two
		dimensions before entering the U-Net, so an (N, d, 1, 1) input becomes
		(N, d, 128, 128).
		"""
		x = x.repeat(1, 1, self.out_size, self.out_size)
		output_tensor = self.unet(x)
		return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]



class UnetGenerator(nn.Module):
	"""U-Net generator assembled from nested UnetSkipConnectionBlock modules."""

	def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
		"""Assemble the U-Net from the bottleneck outwards.

		Parameters:
			input_nc (int)  -- the number of channels in input images
			output_nc (int) -- the number of channels in output images
			num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
								an image of size 128x128 becomes 1x1 at the bottleneck
			ngf (int)       -- channel width of the outermost block's inner convolution;
								deeper blocks use ngf*2, ngf*4 and ngf*8
			norm_layer      -- normalization layer factory
			use_dropout     -- enable dropout in the intermediate ngf*8 blocks only

		Each block wraps the previously built one as its submodule, so the
		structure grows recursively: innermost block, then (num_downs - 5)
		blocks at ngf*8, then three blocks narrowing to ngf, then the outermost
		block mapping input_nc -> output_nc.
		"""
		super().__init__()
		widest = ngf * 8

		# Bottleneck.
		block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)

		# Constant-width levels; these are the only ones that may use dropout.
		for _ in range(num_downs - 5):
			block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)

		# Levels that step the width down from ngf*8 to ngf: (outer_nc, inner_nc).
		narrowing = ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2))
		for outer_nc, inner_nc in narrowing:
			block = UnetSkipConnectionBlock(outer_nc, inner_nc, input_nc=None, submodule=block, norm_layer=norm_layer)

		# Outermost level.
		self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

	def forward(self, input):
		"""Run ``input`` through the assembled U-Net."""
		result = self.model(input)
		return result


class UnetSkipConnectionBlock(nn.Module):
	"""Defines the Unet submodule with skip connection.
		X -------------------identity----------------------
		|-- downsampling -- |submodule| -- upsampling --|
	"""

	def __init__(self, outer_nc, inner_nc, input_nc=None,
				submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
		"""Construct a Unet submodule with skip connections.

		Parameters:
			outer_nc (int) -- the number of filters in the outer conv layer
			inner_nc (int) -- the number of filters in the inner conv layer
			input_nc (int) -- the number of channels in input images/features;
								defaults to outer_nc when None
			submodule (UnetSkipConnectionBlock) -- previously defined submodules
			outermost (bool)    -- if this module is the outermost module
			innermost (bool)    -- if this module is the innermost module
			norm_layer          -- normalization layer (class or functools.partial of one)
			use_dropout (bool)  -- if use dropout layers (intermediate blocks only).

		Layer order inside ``self.model``:
			outermost:    conv, submodule, ReLU, deconv
			innermost:    LeakyReLU, conv, ReLU, deconv, norm
			intermediate: LeakyReLU, conv, norm, submodule, ReLU, deconv, norm[, Dropout(0.5)]
		"""
		super(UnetSkipConnectionBlock, self).__init__()
		self.outermost = outermost

		# Convolutions carry a bias only when the norm layer is InstanceNorm2d.
		norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
		use_bias = norm_cls is nn.InstanceNorm2d

		if input_nc is None:
			input_nc = outer_nc
		downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
							stride=2, padding=1, bias=use_bias)
		downrelu = nn.LeakyReLU(0.2, True)
		uprelu = nn.ReLU(True)

		if outermost:
			# The submodule returns its input concatenated with its output,
			# hence inner_nc * 2 input channels for the up-convolution.
			upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
										kernel_size=4, stride=2,
										padding=1)
			# No final activation: the outermost block emits the raw
			# transposed-convolution output.
			model = [downconv, submodule, uprelu, upconv]
		elif innermost:
			# No submodule here, so the up-convolution sees inner_nc channels.
			upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
										kernel_size=4, stride=2,
										padding=1, bias=use_bias)
			model = [downrelu, downconv, uprelu, upconv, norm_layer(outer_nc)]
		else:
			upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
										kernel_size=4, stride=2,
										padding=1, bias=use_bias)
			model = [downrelu, downconv, norm_layer(inner_nc),
					submodule,
					uprelu, upconv, norm_layer(outer_nc)]
			if use_dropout:
				model.append(nn.Dropout(0.5))

		self.model = nn.Sequential(*model)

	def forward(self, x):
		"""Return the block output; non-outermost blocks prepend their input along dim 1.

		NOTE: in non-outermost blocks the first layer is an in-place LeakyReLU,
		so ``x`` has already been modified by it when it is concatenated.
		"""
		if self.outermost:
			return self.model(x)
		else:   # add skip connections
			return torch.cat([x, self.model(x)], 1)
